# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
# Changes Made from original:
#   import paths
#   pruning operations
# ******************************************************************************
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Basic sequence-to-sequence model with dynamic RNN support."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections

import tensorflow as tf
from tensorflow.contrib.model_pruning import get_masks, get_thresholds

from tensorflow.contrib.model_pruning.python.layers import core_layers
from tensorflow.contrib.model_pruning.python import pruning

from nlp_architect.models.gnmt import model_helper
from nlp_architect.models.gnmt.utils import iterator_utils
from nlp_architect.models.gnmt.utils import misc_utils as utils

utils.check_tensorflow_version()

__all__ = ["BaseModel", "Model"]


class TrainOutputTuple(collections.namedtuple(
    "TrainOutputTuple", ("train_summary", "train_loss", "predict_count",
                         "global_step", "word_count", "batch_size", "grad_norm",
                         "learning_rate", "mask_update_op", "pruning_summary"))):
    """To allow for flexibility in returning different outputs."""
    pass


class EvalOutputTuple(collections.namedtuple(
    "EvalOutputTuple", ("eval_loss", "predict_count",
                        "batch_size"))):
    """To allow for flexibility in returning different outputs."""
    pass


class InferOutputTuple(collections.namedtuple(
    "InferOutputTuple", ("infer_logits", "infer_summary", "sample_id",
                         "sample_words"))):
    """To allow for flexibility in returning different outputs."""
    pass


class BaseModel(object):
    """Sequence-to-sequence base class.
    """

    def __init__(self,
                 hparams,
                 mode,
                 iterator,
                 source_vocab_table,
                 target_vocab_table,
                 reverse_target_vocab_table=None,
                 scope=None,
                 extra_args=None):
        """Create the model.

        Args:
          hparams: Hyperparameter configurations.
          mode: TRAIN | EVAL | INFER
          iterator: Dataset Iterator that feeds data.
          source_vocab_table: Lookup table mapping source words to ids.
          target_vocab_table: Lookup table mapping target words to ids.
          reverse_target_vocab_table: Lookup table mapping ids to target words. Only
            required in INFER mode. Defaults to None.
          scope: scope of the model.
          extra_args: model_helper.ExtraArgs, for passing customizable functions.

        """
        # Set params
        self._set_params_initializer(hparams, mode, iterator,
                                     source_vocab_table, target_vocab_table,
                                     scope, extra_args)

        # Not used in general seq2seq models; when True, ignore decoder & training
        self.extract_encoder_layers = (hasattr(hparams, "extract_encoder_layers")
                                       and hparams.extract_encoder_layers)

        # Train graph
        res = self.build_graph(hparams, scope=scope)
        if not self.extract_encoder_layers:
            self._set_train_or_infer(res, reverse_target_vocab_table, hparams)

        # Saver
        self.saver = tf.train.Saver(
            tf.global_variables(), max_to_keep=hparams.num_keep_ckpts)

    def _set_params_initializer(self,
                                hparams,
                                mode,
                                iterator,
                                source_vocab_table,
                                target_vocab_table,
                                scope,
                                extra_args=None):
        """Set various params for self and initialize."""
        assert isinstance(iterator, iterator_utils.BatchedInput)
        self.iterator = iterator
        self.mode = mode
        self.src_vocab_table = source_vocab_table
        self.tgt_vocab_table = target_vocab_table

        self.src_vocab_size = hparams.src_vocab_size
        self.tgt_vocab_size = hparams.tgt_vocab_size
        self.num_gpus = hparams.num_gpus
        self.time_major = hparams.time_major

        if hparams.use_char_encode:
            assert (not self.time_major), ("Can't use time major for"
                                           " char-level inputs.")

        self.dtype = tf.float32
        self.num_sampled_softmax = hparams.num_sampled_softmax

        # extra_args: to make it flexible for adding external customizable code
        self.single_cell_fn = None
        if extra_args:
            self.single_cell_fn = extra_args.single_cell_fn

        # Set num units
        self.num_units = hparams.num_units

        # Set num layers
        self.num_encoder_layers = hparams.num_encoder_layers
        self.num_decoder_layers = hparams.num_decoder_layers
        assert self.num_encoder_layers
        assert self.num_decoder_layers

        # Set num residual layers
        if hasattr(hparams, "num_residual_layers"):  # for compatibility with common_test_utils
            self.num_encoder_residual_layers = hparams.num_residual_layers
            self.num_decoder_residual_layers = hparams.num_residual_layers
        else:
            self.num_encoder_residual_layers = hparams.num_encoder_residual_layers
            self.num_decoder_residual_layers = hparams.num_decoder_residual_layers

        # Batch size
        self.batch_size = tf.size(self.iterator.source_sequence_length)

        # Global step
        self.global_step = tf.Variable(0, trainable=False)

        # Initializer
        self.random_seed = hparams.random_seed
        initializer = model_helper.get_initializer(
            hparams.init_op, self.random_seed, hparams.init_weight)
        tf.get_variable_scope().set_initializer(initializer)

        # Embeddings
        if extra_args and extra_args.encoder_emb_lookup_fn:
            self.encoder_emb_lookup_fn = extra_args.encoder_emb_lookup_fn
        else:
            self.encoder_emb_lookup_fn = tf.nn.embedding_lookup
        self.init_embeddings(hparams, scope)

    def _set_train_or_infer(self, res, reverse_target_vocab_table, hparams):
        """Set up training and inference."""
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            self.train_loss = res[1]
            self.word_count = tf.reduce_sum(
                self.iterator.source_sequence_length) + tf.reduce_sum(
                self.iterator.target_sequence_length)
        elif self.mode == tf.contrib.learn.ModeKeys.EVAL:
            self.eval_loss = res[1]
        elif self.mode == tf.contrib.learn.ModeKeys.INFER:
            self.infer_logits, _, self.final_context_state, self.sample_id = res
            self.sample_words = reverse_target_vocab_table.lookup(
                tf.to_int64(self.sample_id))

        if self.mode != tf.contrib.learn.ModeKeys.INFER:
            # Count the number of predicted words for computing ppl.
            self.predict_count = tf.reduce_sum(
                self.iterator.target_sequence_length)

        params = tf.trainable_variables()

        # Gradients and SGD update operation for training the model.
        # Arrange for the embedding vars to appear at the beginning.
        if self.mode == tf.contrib.learn.ModeKeys.TRAIN:
            self.learning_rate = tf.constant(hparams.learning_rate)
            # warm-up
            self.learning_rate = self._get_learning_rate_warmup(hparams)
            # decay
            self.learning_rate = self._get_learning_rate_decay(hparams)

            # Optimizer
            if hparams.optimizer == "sgd":
                opt = tf.train.GradientDescentOptimizer(self.learning_rate)
            elif hparams.optimizer == "adam":
                opt = tf.train.AdamOptimizer(self.learning_rate)
            else:
                raise ValueError("Unknown optimizer type %s" % hparams.optimizer)

            # Gradients
            gradients = tf.gradients(
                self.train_loss,
                params,
                colocate_gradients_with_ops=hparams.colocate_gradients_with_ops)

            clipped_grads, grad_norm_summary, grad_norm = model_helper.gradient_clip(
                gradients, max_gradient_norm=hparams.max_gradient_norm)
            self.grad_norm_summary = grad_norm_summary
            self.grad_norm = grad_norm

            self.update = opt.apply_gradients(
                zip(clipped_grads, params), global_step=self.global_step)

            # Summary
            self.train_summary = self._get_train_summary()
        elif self.mode == tf.contrib.learn.ModeKeys.INFER:
            self.infer_summary = self._get_infer_summary(hparams)

        # Print trainable variables
        utils.print_out("# Trainable variables")
        utils.print_out("Format: <name>, <shape>, <(soft) device placement>")
        for param in params:
            utils.print_out("  %s, %s, %s" % (param.name, str(param.get_shape()),
                                              param.op.device))

    def _get_learning_rate_warmup(self, hparams):
        """Get learning rate warmup."""
        warmup_steps = hparams.warmup_steps
        warmup_scheme = hparams.warmup_scheme
        utils.print_out("  learning_rate=%g, warmup_steps=%d, warmup_scheme=%s" %
                        (hparams.learning_rate, warmup_steps, warmup_scheme))

        # Apply inverse decay if global steps less than warmup steps.
        # Inspired by https://arxiv.org/pdf/1706.03762.pdf (Section 5.3)
        # When step < warmup_steps,
        #   learning_rate *= warmup_factor ** (warmup_steps - step)
        if warmup_scheme == "t2t":
            # 0.01^(1/warmup_steps): start with a learning rate 100 times smaller
            warmup_factor = tf.exp(tf.log(0.01) / warmup_steps)
            inv_decay = warmup_factor ** (
                tf.to_float(warmup_steps - self.global_step))
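            # e.g. with warmup_steps=1000: warmup_factor = 0.01**(1/1000)
            # ~= 0.99541, so inv_decay is 0.01 at step 0 (a 100x smaller rate)
            # and reaches 1.0 at step 1000.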
        else:
            raise ValueError("Unknown warmup scheme %s" % warmup_scheme)

        return tf.cond(
            self.global_step < hparams.warmup_steps,
            lambda: inv_decay * self.learning_rate,
            lambda: self.learning_rate,
            name="learning_rate_warmup_cond")

    def _get_decay_info(self, hparams):
        """Return decay info based on decay_scheme."""
        if hparams.decay_scheme in ["luong5", "luong10", "luong234"]:
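            # e.g. "luong10" with num_train_steps=340000: keep the base rate
            # until step 170000, then halve it every 17000 steps (10 halvings).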
            decay_factor = 0.5
            if hparams.decay_scheme == "luong5":
                start_decay_step = int(hparams.num_train_steps / 2)
                decay_times = 5
            elif hparams.decay_scheme == "luong10":
                start_decay_step = int(hparams.num_train_steps / 2)
                decay_times = 10
            elif hparams.decay_scheme == "luong234":
                start_decay_step = int(hparams.num_train_steps * 2 / 3)
                decay_times = 4
            remain_steps = hparams.num_train_steps - start_decay_step
            decay_steps = int(remain_steps / decay_times)
        elif not hparams.decay_scheme:  # no decay
            start_decay_step = hparams.num_train_steps
            decay_steps = 0
            decay_factor = 1.0
        else:
            raise ValueError("Unknown decay scheme %s" % hparams.decay_scheme)
        return start_decay_step, decay_steps, decay_factor

    def _get_learning_rate_decay(self, hparams):
        """Get learning rate decay."""
        start_decay_step, decay_steps, decay_factor = self._get_decay_info(hparams)
        utils.print_out("  decay_scheme=%s, start_decay_step=%d, decay_steps %d, "
                        "decay_factor %g" % (hparams.decay_scheme,
                                             start_decay_step,
                                             decay_steps,
                                             decay_factor))

        return tf.cond(
            self.global_step < start_decay_step,
            lambda: self.learning_rate,
            lambda: tf.train.exponential_decay(
                self.learning_rate,
                (self.global_step - start_decay_step),
                decay_steps, decay_factor, staircase=True),
            name="learning_rate_decay_cond")

    def init_embeddings(self, hparams, scope):
        """Init embeddings."""
        self.embedding_encoder, self.embedding_decoder = (
            model_helper.create_emb_for_encoder_and_decoder(
                share_vocab=hparams.share_vocab,
                src_vocab_size=self.src_vocab_size,
                tgt_vocab_size=self.tgt_vocab_size,
                src_embed_size=self.num_units,
                tgt_embed_size=self.num_units,
                num_enc_partitions=hparams.num_enc_emb_partitions,
                num_dec_partitions=hparams.num_dec_emb_partitions,
                src_vocab_file=hparams.src_vocab_file,
                tgt_vocab_file=hparams.tgt_vocab_file,
                src_embed_file=hparams.src_embed_file,
                tgt_embed_file=hparams.tgt_embed_file,
                use_char_encode=hparams.use_char_encode,
                scope=scope,
                embed_type=hparams.embedding_type, ))

    def _get_train_summary(self):
        """Get train summary."""
        train_summary = tf.summary.merge(
            [tf.summary.scalar("lr", self.learning_rate),
             tf.summary.scalar("train_loss", self.train_loss)] + self.grad_norm_summary)
        return train_summary

    def train(self, sess):
        """Execute train graph."""
        assert self.mode == tf.contrib.learn.ModeKeys.TRAIN
        output_tuple = TrainOutputTuple(train_summary=self.train_summary,
                                        train_loss=self.train_loss,
                                        predict_count=self.predict_count,
                                        global_step=self.global_step,
                                        word_count=self.word_count,
                                        batch_size=self.batch_size,
                                        grad_norm=self.grad_norm,
                                        learning_rate=self.learning_rate,
                                        mask_update_op=self.mask_update_op,
                                        pruning_summary=self.pruning_summary)
        return sess.run([self.update, output_tuple])

    def eval(self, sess):
        """Execute eval graph."""
        assert self.mode == tf.contrib.learn.ModeKeys.EVAL
        output_tuple = EvalOutputTuple(eval_loss=self.eval_loss,
                                       predict_count=self.predict_count,
                                       batch_size=self.batch_size)
        return sess.run(output_tuple)

    def build_graph(self, hparams, scope=None):
        """Create a sequence-to-sequence model with dynamic RNN decoder API.

        Subclasses supply the encoder and the decoder cell via the abstract
        `_build_encoder` and `_build_decoder_cell` methods.
        Args:
          hparams: Hyperparameter configurations.
          scope: VariableScope for the created subgraph; default "dynamic_seq2seq".

        Returns:
          A tuple of the form (logits, loss, final_context_state, sample_id),
          where:
            logits: float32 Tensor [batch_size x num_decoder_symbols].
            loss: the total loss divided by batch_size.
            final_context_state: the final state of decoder RNN.
            sample_id: sampling indices.

        Raises:
          ValueError: if encoder_type differs from mono and bi, or
            attention_option is not (luong | scaled_luong |
            bahdanau | normed_bahdanau).
        """
        utils.print_out("# Creating %s graph ..." % self.mode)

        # Projection
        if not self.extract_encoder_layers:
            with tf.variable_scope(scope or "build_network"):
                with tf.variable_scope("decoder/output_projection"):
                    if hparams.projection_type == 'sparse':
                        self.output_layer = core_layers.MaskedFullyConnected(
                            hparams.tgt_vocab_size, use_bias=False, name="output_projection")
                    elif hparams.projection_type == 'dense':
                        self.output_layer = tf.layers.Dense(
                            hparams.tgt_vocab_size, use_bias=False, name="output_projection")
                    else:
                        raise ValueError("Unknown projection type %s!" % hparams.projection_type)

        with tf.variable_scope(scope or "dynamic_seq2seq", dtype=self.dtype):
            # Encoder
            if hparams.language_model:  # no encoder for language modeling
                utils.print_out("  language modeling: no encoder")
                self.encoder_outputs = None
                encoder_state = None
            else:
                self.encoder_outputs, encoder_state = self._build_encoder(hparams)

            # Skip decoder if extracting only encoder layers
            if self.extract_encoder_layers:
                return

            # Decoder
            logits, decoder_cell_outputs, sample_id, final_context_state = (
                self._build_decoder(self.encoder_outputs, encoder_state, hparams))

            # Loss
            if self.mode != tf.contrib.learn.ModeKeys.INFER:
                with tf.device(model_helper.get_device_str(self.num_encoder_layers - 1,
                                                           self.num_gpus)):
                    loss = self._compute_loss(logits, decoder_cell_outputs)
            else:
                loss = tf.constant(0.0)

            # model pruning
            if hparams.pruning_hparams is not None:
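                # hparams.pruning_hparams is a comma-separated name=value string
                # understood by tf.contrib.model_pruning, e.g. (illustrative
                # values only):
                #   "begin_pruning_step=1000,end_pruning_step=20000,target_sparsity=0.9"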
                pruning_hparams = pruning.get_pruning_hparams().parse(
                    hparams.pruning_hparams)
                self.p = pruning.Pruning(pruning_hparams, global_step=self.global_step)
                self.mask_update_op = self.p.conditional_mask_update_op()
                masks = get_masks()
                thresholds = get_thresholds()
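                # Each mask is a binary tensor multiplied elementwise into its
                # weight; tf.nn.zero_fraction(mask) is therefore the layer's
                # current sparsity.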
                masks_s = []
                for index, mask in enumerate(masks):
                    masks_s.append(tf.summary.scalar(mask.name + '/sparsity',
                                                     tf.nn.zero_fraction(mask)))
                    masks_s.append(tf.summary.scalar(thresholds[index].op.name + '/threshold',
                                                     thresholds[index]))
                    masks_s.append(tf.summary.histogram(mask.name + '/mask_tensor', mask))
                self.pruning_summary = tf.summary.merge(
                    [tf.summary.scalar('sparsity', self.p._sparsity),
                     tf.summary.scalar('last_mask_update_step', self.p._last_update_step)]
                    + masks_s)
            else:
                self.mask_update_op = tf.no_op()
                self.pruning_summary = tf.no_op()

            return logits, loss, final_context_state, sample_id

    @abc.abstractmethod
    def _build_encoder(self, hparams):
        """Subclass must implement this.

        Build and run an RNN encoder.

        Args:
          hparams: Hyperparameters configurations.

        Returns:
          A tuple of encoder_outputs and encoder_state.
        """
        pass

    def _build_encoder_cell(self, hparams, num_layers, num_residual_layers,
                            base_gpu=0):
        """Build a multi-layer RNN cell that can be used by the encoder."""

        return model_helper.create_rnn_cell(
            unit_type=hparams.unit_type,
            num_units=self.num_units,
            num_layers=num_layers,
            num_residual_layers=num_residual_layers,
            forget_bias=hparams.forget_bias,
            dropout=hparams.dropout,
            num_gpus=hparams.num_gpus,
            mode=self.mode,
            base_gpu=base_gpu,
            single_cell_fn=self.single_cell_fn)

    def _get_infer_maximum_iterations(self, hparams, source_sequence_length):
        """Maximum decoding steps at inference time."""
        if hparams.tgt_max_len_infer:
            maximum_iterations = hparams.tgt_max_len_infer
            utils.print_out("  decoding maximum_iterations %d" % maximum_iterations)
        else:
            # TODO(thangluong): add decoding_length_factor flag
            decoding_length_factor = 2.0
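            # e.g. with a longest source sentence of 25 tokens, the decoder runs
            # for at most round(25 * 2.0) = 50 steps.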
            max_encoder_length = tf.reduce_max(source_sequence_length)
            maximum_iterations = tf.to_int32(tf.round(
                tf.to_float(max_encoder_length) * decoding_length_factor))
        return maximum_iterations

    def _build_decoder(self, encoder_outputs, encoder_state, hparams):
        """Build and run an RNN decoder with a final projection layer.

        Args:
          encoder_outputs: The outputs of encoder for every time step.
          encoder_state: The final state of the encoder.
          hparams: The Hyperparameters configurations.

        Returns:
          A tuple of final logits and final decoder state:
            logits: size [time, batch_size, vocab_size] when time_major=True.
        """
        tgt_sos_id = tf.cast(self.tgt_vocab_table.lookup(tf.constant(hparams.sos)),
                             tf.int32)
        tgt_eos_id = tf.cast(self.tgt_vocab_table.lookup(tf.constant(hparams.eos)),
                             tf.int32)
        iterator = self.iterator

        # maximum_iteration: The maximum decoding steps.
        maximum_iterations = self._get_infer_maximum_iterations(
            hparams, iterator.source_sequence_length)

        # Decoder.
        with tf.variable_scope("decoder") as decoder_scope:
            cell, decoder_initial_state = self._build_decoder_cell(
                hparams, encoder_outputs, encoder_state,
                iterator.source_sequence_length)

            # Optional ops depend on which mode we are in and which loss
            # function we are using.
            logits = tf.no_op()
            decoder_cell_outputs = None

            # Train or eval
            if self.mode != tf.contrib.learn.ModeKeys.INFER:
                # decoder_emb_inp: [max_time, batch_size, num_units]
                target_input = iterator.target_input
                if self.time_major:
                    target_input = tf.transpose(target_input)
                decoder_emb_inp = tf.nn.embedding_lookup(
                    self.embedding_decoder, target_input)

                # Helper
                helper = tf.contrib.seq2seq.TrainingHelper(
                    decoder_emb_inp, iterator.target_sequence_length,
                    time_major=self.time_major)

                # Decoder
                my_decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell,
                    helper,
                    decoder_initial_state, )

                # Dynamic decoding
                outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(
                    my_decoder,
                    output_time_major=self.time_major,
                    swap_memory=True,
                    scope=decoder_scope)

                sample_id = outputs.sample_id

                if self.num_sampled_softmax > 0:
                    # Note: this is required when using sampled_softmax_loss.
                    decoder_cell_outputs = outputs.rnn_output

                # Note: there's a subtle difference here between train and inference.
                # We could have set output_layer when create my_decoder
                #   and shared more code between train and inference.
                # We chose to apply the output_layer to all timesteps for speed:
                #   10% improvements for small models & 20% for larger ones.
                # If memory is a concern, we should apply output_layer per timestep.
                num_layers = self.num_decoder_layers
                num_gpus = self.num_gpus
                device_id = num_layers if num_layers < num_gpus else (num_layers - 1)
                # Colocate output layer with the last RNN cell if there is no extra GPU
                # available. Otherwise, put last layer on a separate GPU.
                with tf.device(model_helper.get_device_str(device_id, num_gpus)):
                    logits = self.output_layer(outputs.rnn_output)

                if self.num_sampled_softmax > 0:
                    logits = tf.no_op()  # unused when using sampled softmax loss.

            # Inference
            else:
                infer_mode = hparams.infer_mode
                start_tokens = tf.fill([self.batch_size], tgt_sos_id)
                end_token = tgt_eos_id
                utils.print_out(
                    "  decoder: infer_mode=%s, beam_width=%d, length_penalty=%f" % (
                        infer_mode, hparams.beam_width, hparams.length_penalty_weight))

                if infer_mode == "beam_search":
                    beam_width = hparams.beam_width
                    length_penalty_weight = hparams.length_penalty_weight

                    my_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                        cell=cell,
                        embedding=self.embedding_decoder,
                        start_tokens=start_tokens,
                        end_token=end_token,
                        initial_state=decoder_initial_state,
                        beam_width=beam_width,
                        output_layer=self.output_layer,
                        length_penalty_weight=length_penalty_weight)
                elif infer_mode == "sample":
                    # Helper
                    sampling_temperature = hparams.sampling_temperature
                    assert sampling_temperature > 0.0, (
                        "sampling_temperature must be greater than 0.0 when using"
                        " the sample decoder.")
                    helper = tf.contrib.seq2seq.SampleEmbeddingHelper(
                        self.embedding_decoder, start_tokens, end_token,
                        softmax_temperature=sampling_temperature,
                        seed=self.random_seed)
                elif infer_mode == "greedy":
                    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                        self.embedding_decoder, start_tokens, end_token)
                else:
                    raise ValueError("Unknown infer_mode '%s'" % infer_mode)

                if infer_mode != "beam_search":
                    my_decoder = tf.contrib.seq2seq.BasicDecoder(
                        cell,
                        helper,
                        decoder_initial_state,
                        output_layer=self.output_layer  # applied per timestep
                    )

                # Dynamic decoding
                outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(
                    my_decoder,
                    maximum_iterations=maximum_iterations,
                    output_time_major=self.time_major,
                    swap_memory=True,
                    scope=decoder_scope)

                if infer_mode == "beam_search":
                    sample_id = outputs.predicted_ids
                else:
                    logits = outputs.rnn_output
                    sample_id = outputs.sample_id

        return logits, decoder_cell_outputs, sample_id, final_context_state

    def get_max_time(self, tensor):
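        """Return the size of the time axis, statically if known, else dynamically."""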
        time_axis = 0 if self.time_major else 1
        return tensor.shape[time_axis].value or tf.shape(tensor)[time_axis]

    @abc.abstractmethod
    def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                            source_sequence_length):
        """Subclass must implement this.

        Args:
          hparams: Hyperparameters configurations.
          encoder_outputs: The outputs of encoder for every time step.
          encoder_state: The final state of the encoder.
          source_sequence_length: sequence length of encoder_outputs.

        Returns:
          A tuple of a multi-layer RNN cell used by the decoder and the initial
          state of the decoder RNN.
        """
        pass

    def _softmax_cross_entropy_loss(
            self, logits, decoder_cell_outputs, labels):
        """Compute softmax loss or sampled softmax loss."""
        if self.num_sampled_softmax > 0:
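            # Sampled softmax approximates the full softmax by scoring only
            # num_sampled_softmax sampled classes (plus the target) per token.
            # The decoder activations are flattened to [time * batch, num_units]
            # (one row per target token) and the per-token losses are reshaped
            # back into a sequence afterwards.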

            is_sequence = (decoder_cell_outputs.shape.ndims == 3)

            if is_sequence:
                labels = tf.reshape(labels, [-1, 1])
                inputs = tf.reshape(decoder_cell_outputs, [-1, self.num_units])

            crossent = tf.nn.sampled_softmax_loss(
                weights=tf.transpose(self.output_layer.kernel),
                biases=self.output_layer.bias or tf.zeros([self.tgt_vocab_size]),
                labels=labels,
                inputs=inputs,
                num_sampled=self.num_sampled_softmax,
                num_classes=self.tgt_vocab_size,
                partition_strategy="div",
                seed=self.random_seed)

            if is_sequence:
                if self.time_major:
                    crossent = tf.reshape(crossent, [-1, self.batch_size])
                else:
                    crossent = tf.reshape(crossent, [self.batch_size, -1])

        else:
            crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels, logits=logits)

        return crossent

    def _compute_loss(self, logits, decoder_cell_outputs):
        """Compute optimization loss."""
        target_output = self.iterator.target_output
        if self.time_major:
            target_output = tf.transpose(target_output)
        max_time = self.get_max_time(target_output)

        crossent = self._softmax_cross_entropy_loss(
            logits, decoder_cell_outputs, target_output)

        target_weights = tf.sequence_mask(
            self.iterator.target_sequence_length, max_time, dtype=self.dtype)
        if self.time_major:
            target_weights = tf.transpose(target_weights)

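        # Sum the masked cross-entropy and normalize by batch size rather than
        # by token count; predict_count tracks the number of target tokens so
        # per-word perplexity can be derived separately.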
        loss = tf.reduce_sum(
            crossent * target_weights) / tf.to_float(self.batch_size)
        return loss

    def _get_infer_summary(self, hparams):
        del hparams
        return tf.no_op()

    def infer(self, sess):
        assert self.mode == tf.contrib.learn.ModeKeys.INFER
        output_tuple = InferOutputTuple(infer_logits=self.infer_logits,
                                        infer_summary=self.infer_summary,
                                        sample_id=self.sample_id,
                                        sample_words=self.sample_words)
        return sess.run(output_tuple)

    def decode(self, sess):
        """Decode a batch.

        Args:
          sess: tensorflow session to use.

        Returns:
          A tuple consisting of outputs and infer_summary.
            outputs: of size [batch_size, time]
        """
        output_tuple = self.infer(sess)
        sample_words = output_tuple.sample_words
        infer_summary = output_tuple.infer_summary

        # make sure outputs is of shape [batch_size, time] or [beam_width,
        # batch_size, time] when using beam search.
        if self.time_major:
            sample_words = sample_words.transpose()
        elif sample_words.ndim == 3:
            # beam search output in [batch_size, time, beam_width] shape.
            sample_words = sample_words.transpose([2, 0, 1])
        return sample_words, infer_summary

    def build_encoder_states(self, include_embeddings=False):
        """Stack encoder states and return tensor [batch, length, layer, size]."""
        assert self.mode == tf.contrib.learn.ModeKeys.INFER
        if include_embeddings:
            stack_state_list = tf.stack(
                [self.encoder_emb_inp] + self.encoder_state_list, 2)
        else:
            stack_state_list = tf.stack(self.encoder_state_list, 2)

        # transform from [length, batch, ...] -> [batch, length, ...]
        if self.time_major:
            stack_state_list = tf.transpose(stack_state_list, [1, 0, 2, 3])

        return stack_state_list


class Model(BaseModel):
    """Sequence-to-sequence dynamic model.

    This class implements a multi-layer recurrent neural network as encoder,
    and a multi-layer recurrent neural network decoder.
    """

    def _build_encoder_from_sequence(self, hparams, sequence, sequence_length):
        """Build an encoder from a sequence.

        Args:
          hparams: hyperparameters.
          sequence: tensor with input sequence data.
          sequence_length: tensor with length of the input sequence.

        Returns:
          encoder_outputs: RNN encoder outputs.
          encoder_state: RNN encoder state.

        Raises:
          ValueError: if encoder_type is neither "uni" nor "bi".
        """
        num_layers = self.num_encoder_layers
        num_residual_layers = self.num_encoder_residual_layers

        if self.time_major:
            sequence = tf.transpose(sequence)

        with tf.variable_scope("encoder") as scope:
            dtype = scope.dtype

            self.encoder_emb_inp = self.encoder_emb_lookup_fn(
                self.embedding_encoder, sequence)

            # Encoder_outputs: [max_time, batch_size, num_units]
            if hparams.encoder_type == "uni":
                utils.print_out("  num_layers = %d, num_residual_layers=%d" %
                                (num_layers, num_residual_layers))
                cell = self._build_encoder_cell(hparams, num_layers,
                                                num_residual_layers)

                encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                    cell,
                    self.encoder_emb_inp,
                    dtype=dtype,
                    sequence_length=sequence_length,
                    time_major=self.time_major,
                    swap_memory=True)
            elif hparams.encoder_type == "bi":
                num_bi_layers = int(num_layers / 2)
                num_bi_residual_layers = int(num_residual_layers / 2)
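                # Each bidirectional layer stacks a forward and a backward cell,
                # so only num_layers / 2 layers are created; their outputs are
                # concatenated, giving encoder_outputs 2 * num_units features.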
                utils.print_out("  num_bi_layers = %d, num_bi_residual_layers=%d" %
                                (num_bi_layers, num_bi_residual_layers))

                encoder_outputs, bi_encoder_state = (
                    self._build_bidirectional_rnn(
                        inputs=self.encoder_emb_inp,
                        sequence_length=sequence_length,
                        dtype=dtype,
                        hparams=hparams,
                        num_bi_layers=num_bi_layers,
                        num_bi_residual_layers=num_bi_residual_layers))

                if num_bi_layers == 1:
                    encoder_state = bi_encoder_state
                else:
                    # alternately concat forward and backward states
                    encoder_state = []
                    for layer_id in range(num_bi_layers):
                        encoder_state.append(bi_encoder_state[0][layer_id])  # forward
                        encoder_state.append(bi_encoder_state[1][layer_id])  # backward
                    encoder_state = tuple(encoder_state)
            else:
                raise ValueError("Unknown encoder_type %s" % hparams.encoder_type)

        # Use the top layer for now
        self.encoder_state_list = [encoder_outputs]

        return encoder_outputs, encoder_state

    def _build_encoder(self, hparams):
        """Build encoder from source."""
        utils.print_out("# Build a basic encoder")
        return self._build_encoder_from_sequence(
            hparams, self.iterator.source, self.iterator.source_sequence_length)

    def _build_bidirectional_rnn(self, inputs, sequence_length,
                                 dtype, hparams,
                                 num_bi_layers,
                                 num_bi_residual_layers,
                                 base_gpu=0):
        """Create and call bidirectional RNN cells.

        Args:
          num_bi_residual_layers: Number of residual layers from top to bottom. For
            example, if `num_bi_layers=4` and `num_bi_residual_layers=2`, the last
            2 RNN layers in each direction will be wrapped with `ResidualWrapper`.
          base_gpu: The gpu device id to use for the first forward RNN layer. The
            i-th forward RNN layer will use `(base_gpu + i) % num_gpus` as its
            device id. The `base_gpu` for the backward RNN cells is `(base_gpu +
            num_bi_layers)`.

        Returns:
          The concatenated bidirectional output and the bidirectional RNN cell's
          state.
        """
        # Construct forward and backward cells
        fw_cell = self._build_encoder_cell(hparams,
                                           num_bi_layers,
                                           num_bi_residual_layers,
                                           base_gpu=base_gpu)
        bw_cell = self._build_encoder_cell(hparams,
                                           num_bi_layers,
                                           num_bi_residual_layers,
                                           base_gpu=(base_gpu + num_bi_layers))

        bi_outputs, bi_state = tf.nn.bidirectional_dynamic_rnn(
            fw_cell,
            bw_cell,
            inputs,
            dtype=dtype,
            sequence_length=sequence_length,
            time_major=self.time_major,
            swap_memory=True)

        return tf.concat(bi_outputs, -1), bi_state

    def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                            source_sequence_length, base_gpu=0):
        """Build an RNN cell that can be used by the decoder."""
        # We only make use of encoder_outputs in attention-based models
        if hparams.attention:
            raise ValueError("BasicModel doesn't support attention.")

        cell = model_helper.create_rnn_cell(
            unit_type=hparams.unit_type,
            num_units=self.num_units,
            num_layers=self.num_decoder_layers,
            num_residual_layers=self.num_decoder_residual_layers,
            forget_bias=hparams.forget_bias,
            dropout=hparams.dropout,
            num_gpus=self.num_gpus,
            mode=self.mode,
            single_cell_fn=self.single_cell_fn,
            base_gpu=base_gpu
        )

        if hparams.language_model:
            encoder_state = cell.zero_state(self.batch_size, self.dtype)
        elif not hparams.pass_hidden_state:
            raise ValueError("For non-attentional model, "
                             "pass_hidden_state needs to be set to True")

        # For beam search, we need to replicate encoder infos beam_width times
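        # (tile_batch repeats each batch entry beam_width times, e.g. a
        # [batch_size, num_units] state becomes [batch_size * beam_width,
        # num_units]).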
        if self.mode == tf.contrib.learn.ModeKeys.INFER and hparams.infer_mode == "beam_search":
            decoder_initial_state = tf.contrib.seq2seq.tile_batch(
                encoder_state, multiplier=hparams.beam_width)
        else:
            decoder_initial_state = encoder_state

        return cell, decoder_initial_state
